# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import random
random.seed(0)
from typing import List

import torch
from transformers import LlamaForCausalLM, LlamaTokenizer
import hyperparse
from wm.utools import calculate_checksum

# Module-wide compute device: the default CUDA device when one is visible, else CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

class WmGenerator:
    """LLaMA text generator exposing the seeding utilities used by watermarks.

    This base class samples with plain temperature / top-p and embeds no
    watermark; subclasses override :meth:`sample_next`. The Llama type hints
    are written as strings so the class can be defined without the
    ``transformers`` names being resolved at definition time.
    """

    def __init__(
        self,
        model: "LlamaForCausalLM",
        tokenizer: "LlamaTokenizer",
        ngram: int = 1,
        seed: int = 0,
        seeding: str = 'hash',
        salt_key: int = 35317,
        payload: int = 0,
        args=None,
    ):
        """
        Args:
            model: causal LM producing next-token logits.
            tokenizer: tokenizer paired with ``model``.
            ngram: number of preceding tokens handed to ``sample_next`` for seeding.
            seed: base seed; initial value of the 'hash' scheme and seed of ``self.rng``.
            seeding: seed-derivation scheme: 'hash', 'additive', 'skip' or 'min'.
            salt_key: default salt used by ``get_seed_rng``.
            payload: shift applied by subclasses to their secret vector / bias.
            args: run options; ``args.nanchor`` and ``args.nchecksum`` are read
                when the "mkey" or "rkey" user mode is active.
        """
        # model config
        self.tokenizer = tokenizer
        self.model = model
        # Older LLaMA conversions expose ``max_sequence_length``; current
        # ``LlamaConfig`` only defines ``max_position_embeddings``.
        self.max_seq_len = (
            getattr(model.config, "max_sequence_length", None)
            or model.config.max_position_embeddings
        )
        self.pad_id = model.config.pad_token_id
        self.eos_id = model.config.eos_token_id
        # watermark config
        self.ngram = ngram
        self.salt_key = salt_key
        self.seed = seed
        # Drawn from torch's global RNG (not self.rng), so the table depends
        # on the global torch seed at construction time.
        self.hashtable = torch.randperm(1000003)
        self.seeding = seeding
        self.rng = torch.Generator()
        self.rng.manual_seed(self.seed)
        self.payload = payload
        # usermode is dict-like: tested with `in` and indexed by mode name below.
        self.usermode, self.usermode_str = hyperparse.parse("usermode")
        print(self.usermode)
        self.args = args
        if "mkey" in self.usermode:
            if self.usermode["mkey"] is None:
                self.key01 = "10"
            elif type(self.usermode["mkey"]) is int:
                self.key01 = str(self.usermode["mkey"])
            # NOTE(review): any other "mkey" value leaves self.key01 unset and
            # add_anchorcheck below fails with AttributeError.
        if "rkey" in self.usermode:
            # "rkey" takes precedence over "mkey": an all-zero key of that length.
            self.key01 = "0" * int(self.usermode["rkey"])
        if "mkey" in self.usermode or "rkey" in self.usermode:
            self.anchor = ("10" * 10)[:args.nanchor]
            self.key01 = self.add_anchorcheck(self.key01)
            print("key01: ", self.key01)
        if "chars" in self.usermode:
            self.chars = [ord(char) for char in self.usermode["chars"]]

    def add_anchorcheck(self, key01):
        """Return ``key01`` prefixed by the anchor bits and its checksum bits.

        Reads ``self.anchor`` and ``self.args.nchecksum``.
        """
        return self.anchor + calculate_checksum(key01, parity_bits_count=self.args.nchecksum) + key01

    def hashint(self, integer_tensor: torch.LongTensor) -> torch.LongTensor:
        """Map integer(s) through the fixed random permutation ``self.hashtable``.

        Accepts a tensor or a plain Python int (the 'additive' and 'skip'
        schemes pass ints); the result is a tensor of the same shape.
        Adapted from https://github.com/jwkirchenbauer/lm-watermarking
        """
        return self.hashtable[torch.as_tensor(integer_tensor).cpu() % len(self.hashtable)]

    def get_seed_rng(
        self,
        input_ids: torch.LongTensor,
        salt_key=None,
    ) -> int:
        """
        Derive an integer RNG seed from the context tokens ``input_ids``.
        Adapted from https://github.com/jwkirchenbauer/lm-watermarking

        Args:
            input_ids: 1-D tensor of context token ids.
            salt_key: salt mixed into the hash; defaults to ``self.salt_key``.

        Returns:
            The seed as a Python int.

        Raises:
            ValueError: if ``self.seeding`` is not a known scheme.
        """
        salt_key = salt_key if salt_key is not None else self.salt_key
        if self.seeding == 'hash':
            # Polynomial rolling hash over the tokens, starting from self.seed.
            seed = self.seed
            for i in input_ids:
                seed = (seed * salt_key + i.item()) % (2 ** 64 - 1)
            return seed
        if self.seeding == 'additive':
            return self.hashint(salt_key * torch.sum(input_ids).item()).item()
        if self.seeding == 'skip':
            # Only the first (oldest) context token contributes.
            return self.hashint(salt_key * input_ids[0].item()).item()
        if self.seeding == 'min':
            return torch.min(self.hashint(salt_key * input_ids)).item()
        raise ValueError(f"Unknown seeding scheme: {self.seeding!r}")

    @torch.no_grad()
    def generate(
        self,
        prompts: List[str],
        max_gen_len: int,
        temperature: float = 0.8,
        top_p: float = 0.95,
    ) -> List[str]:
        """
        Generate text from prompts.
        Adapted from https://github.com/facebookresearch/llama/

        Args:
            prompts: input strings, encoded without special tokens.
            max_gen_len: maximum number of tokens generated after each prompt.
            temperature: sampling temperature forwarded to ``sample_next``.
            top_p: nucleus threshold forwarded to ``sample_next``.

        Returns:
            One decoded string per prompt (prompt included), cut at the first
            EOS token if any.
        """
        if not prompts:
            return []

        bsz = len(prompts)
        prompt_tokens = [self.tokenizer.encode(x, add_special_tokens=False) for x in prompts]
        prompt_lens = [len(t) for t in prompt_tokens]
        min_prompt_size = min(prompt_lens)
        max_prompt_size = max(prompt_lens)
        total_len = min(self.max_seq_len, max_gen_len + max_prompt_size)

        # The filler value is never fed to the model: every position >= the
        # shortest prompt is written at its step before being read.
        pad_id = self.pad_id if self.pad_id is not None else 0
        tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long, device=device)
        for k, t in enumerate(prompt_tokens):
            tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long)
        # True where a position holds a prompt token that must be kept. Built
        # from lengths (not `tokens != pad_id`) so prompt tokens equal to the
        # pad id are not overwritten by generated ones.
        lens = torch.tensor(prompt_lens, device=device)
        input_text_mask = torch.arange(total_len, device=device)[None, :] < lens[:, None]

        start_pos = min_prompt_size
        prev_pos = 0
        outputs = None
        for cur_pos in range(start_pos, total_len):
            # Incremental decoding: after the first step only the newest token
            # is fed, together with the cached keys/values.
            outputs = self.model.forward(
                tokens[:, prev_pos:cur_pos],
                use_cache=True,
                past_key_values=outputs.past_key_values if prev_pos > 0 else None,
            )
            if "posenc" in self.usermode:
                self.cur_pos = cur_pos - start_pos - self.ngram - 1
            # NOTE(review): when cur_pos < ngram the slice start is negative,
            # which wraps around instead of clamping to 0 — confirm intended.
            ngram_tokens = tokens[:, cur_pos - self.ngram:cur_pos]
            next_toks = self.sample_next(outputs.logits[:, -1, :], ngram_tokens, temperature, top_p)
            tokens[:, cur_pos] = torch.where(input_text_mask[:, cur_pos], tokens[:, cur_pos], next_toks)
            prev_pos = cur_pos

        decoded = []
        for i, t in enumerate(tokens.tolist()):
            # cut to max gen len
            t = t[: len(prompt_tokens[i]) + max_gen_len]
            # cut to eos tok if any
            try:
                t = t[: t.index(self.eos_id)]
            except ValueError:
                pass
            decoded.append(self.tokenizer.decode(t))

        return decoded

    def sample_next(
        self,
        logits: torch.FloatTensor,  # (bsz, vocab_size): logits for last token
        ngram_tokens: torch.LongTensor,  # (bsz, ngram): tokens to consider when seeding (unused here)
        temperature: float = 0.8,  # temperature for sampling
        top_p: float = 0.95,  # top p for sampling
    ) -> torch.LongTensor:
        """Vanilla (unwatermarked) sampling with temperature and top-p.

        Returns a 1-D tensor with one token id per batch row; greedy argmax
        when ``temperature`` is not positive.
        """
        if temperature > 0:
            probs = torch.softmax(logits / temperature, dim=-1)
            probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
            probs_sum = torch.cumsum(probs_sort, dim=-1)
            # Keep the smallest prefix whose mass reaches top_p.
            mask = probs_sum - probs_sort > top_p
            probs_sort[mask] = 0.0
            probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
            next_token = torch.multinomial(probs_sort, num_samples=1)  # index into sorted order
            next_token = torch.gather(probs_idx, -1, next_token)  # map back to vocab ids
        else:
            next_token = torch.argmax(logits, dim=-1)
        next_token = next_token.reshape(-1)
        return next_token

class OpenaiGenerator(WmGenerator):
    """ Generate text using LLaMA and Aaronson's watermarking method. """

    def sample_next(
        self,
        logits: torch.FloatTensor,  # (bsz, vocab_size): logits for last token
        ngram_tokens: torch.LongTensor,  # (bsz, ngram): tokens to consider when seeding
        temperature: float = 0.8,  # temperature for sampling
        top_p: float = 0.95,  # top p for sampling
    ) -> torch.LongTensor:
        """
        From ngram tokens, select the next token based on the following:
        - hash the ngram tokens and get a seed
        - use the seed to generate V random numbers r in [0, 1]
        - select argmax ( r^(1/p) ) over the top-p filtered distribution p
        The payload (the message) is encoded by rolling the secret vector r by `payload`.

        In "gaussian" user mode, r is standard normal instead and the score is
        log(p) + r. Rows skipped by "mix" mode (``key01[ii] is None``) keep
        their filtered probabilities, so the argmax picks their most likely
        token. With ``temperature <= 0`` plain greedy decoding is used, with
        no watermark.

        Returns:
            1-D tensor with one token id per batch row.
        """
        if temperature > 0:
            probs = torch.softmax(logits / temperature, dim=-1)
            probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
            probs_sum = torch.cumsum(probs_sort, dim=-1)
            mask = probs_sum - probs_sort > top_p
            probs_sort[mask] = 0.0
            probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
            for ii in range(ngram_tokens.shape[0]):  # batch of texts
                if not self._select_salt_key(ii, ngram_tokens[ii]):
                    continue
                if "gaussian" in self.usermode:
                    rs = self._secret_vector(ngram_tokens[ii], torch.randn, probs_sort.device)
                    probs_sort[ii] = probs_sort[ii].log() + rs[probs_idx[ii]]
                else:
                    rs = self._secret_vector(ngram_tokens[ii], torch.rand, probs_sort.device)
                    # compute r^(1/p); filtered-out tokens (p = 0) score 0
                    probs_sort[ii] = torch.pow(rs[probs_idx[ii]], 1 / probs_sort[ii])
            # select argmax ( r^(1/p) ), then map sorted position back to vocab id
            next_token = torch.argmax(probs_sort, dim=-1, keepdim=True)
            next_token = torch.gather(probs_idx, -1, next_token)
        else:
            next_token = torch.argmax(logits, dim=-1)
        next_token = next_token.reshape(-1)
        return next_token

    def _secret_vector(self, ngram, sampler, target_device):
        """Draw the context-seeded secret vector (one entry per tokenizer id), rolled by payload.

        ``sampler`` is ``torch.rand`` or ``torch.randn``; the seed comes from
        ``get_seed_rng(ngram)`` with the current ``self.salt_key``.
        """
        seed = self.get_seed_rng(ngram)
        self.rng.manual_seed(seed)
        rs = sampler(self.tokenizer.vocab_size, generator=self.rng)
        return rs.roll(-self.payload).to(target_device)

    def _select_salt_key(self, ii, ngram) -> bool:
        """Set ``self.salt_key`` (and possibly ``self.chars``) for batch row ``ii``.

        The active user modes are applied in order; later modes overwrite the
        salt chosen by earlier ones. Returns False when the row must be left
        unwatermarked ("mix" mode with ``self.key01[ii] is None``).
        """
        modes = self.usermode

        def key_hash():
            # Salt-free hash of the context; does not depend on self.salt_key.
            return self.get_seed_rng(ngram, salt_key=0)

        def row_key():
            # One key per batch row in "seedbsz" mode, otherwise a shared key.
            return self.key01[ii] if "seedbsz" in modes else self.key01

        if "mkey" in modes or "rkey" in modes:
            key01 = row_key()
            keyid = key_hash() % len(key01)
            if "randsalt" in modes:
                keyid = random.choice(range(len(key01)))
            # A "0" bit selects salt -1; a "1" bit uses its own index as salt.
            self.salt_key = -1 if key01[keyid] == "0" else keyid
        if "chars" in modes:
            self.salt_key = self.chars[key_hash() % len(self.chars)]
        if "nchars" in modes:
            self.chars = [ord(char) for char in row_key()]
            self.salt_key = self.chars[key_hash() % len(self.chars)]
        if "brick3" in modes:
            self.chars = [ord(char) for char in row_key()]
            # Pack the character codes into one base-256 integer.
            uid = 0
            for c in self.chars:
                uid = uid * 256 + c
            self.salt_key = uid
        if "mix" in modes:
            key = self.key01[ii]
            if key is None:
                return False
            self.salt_key = key
            if "alt2" in modes:
                # Choose one of two salts; later options override earlier ones.
                pos = key_hash() % 2
                if "posenc" in modes:
                    pos = self.cur_pos % 2
                if "unb" in modes:
                    pos = int(key_hash() % modes["unb"] > 0)
                if "r" in modes:
                    pos = int((key_hash() % 10) / 10 > modes["r"])
                self.salt_key = key[pos]
            elif "alt" in modes:
                self.salt_key = key[key_hash() % len(key)]
        return True

class MarylandGenerator(WmGenerator):
    """ Generate text using LLaMA and Maryland's watermarking method. """

    def __init__(
        self,
        *args,
        gamma: float = 0.5,
        delta: float = 1.0,
        **kwargs
    ):
        """
        Args:
            gamma: fraction of the vocabulary placed in each greenlist.
            delta: bias added to the logits of greenlist tokens.
            *args, **kwargs: forwarded to ``WmGenerator.__init__``.
        """
        super().__init__(*args, **kwargs)
        self.gamma = gamma
        self.delta = delta

    def sample_next(
        self,
        logits: torch.FloatTensor,  # (bsz, vocab_size): logits for last token
        ngram_tokens: torch.LongTensor,  # (bsz, ngram): tokens to consider when seeding
        temperature: float = 0.8,  # temperature for sampling
        top_p: float = 0.95,  # top p for sampling
    ) -> torch.LongTensor:
        """
        From ngram tokens, select the next token based on the following:
        - hash the ngram tokens and get a seed
        - use the seed to partition the vocabulary into a greenlist (gamma*V words) and the rest
        - add delta to greenlist words' logits, then sample with temperature / top-p
        The payload (the message) is encoded by rolling the greenlist bias vector by `payload`.

        Returns:
            1-D tensor with one token id per batch row (argmax of the biased
            logits when ``temperature <= 0``).
        """
        logits = self.logits_processor(logits, ngram_tokens)
        if temperature > 0:
            probs = torch.softmax(logits / temperature, dim=-1)
            probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
            probs_sum = torch.cumsum(probs_sort, dim=-1)
            mask = probs_sum - probs_sort > top_p
            probs_sort[mask] = 0.0
            probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
            next_token = torch.multinomial(probs_sort, num_samples=1)  # index into sorted order
            next_token = torch.gather(probs_idx, -1, next_token)  # map back to vocab ids
        else:
            next_token = torch.argmax(logits, dim=-1)
        next_token = next_token.reshape(-1)
        return next_token

    def logits_processor(self, logits, ngram_tokens):
        """Return a copy of ``logits`` with ``delta`` added to each row's greenlist tokens.

        The greenlist is the first ``int(gamma * vocab_size)`` entries of a
        permutation seeded from the row's context. Rows skipped in "mix" mode
        (``key01[ii] is None``) are left unbiased. The input is not modified.
        """
        _, vocab_size = logits.shape
        logits = logits.clone()
        for ii in range(ngram_tokens.shape[0]):  # batch of texts
            if not self._select_mix_salt_key(ii, ngram_tokens[ii]):
                continue
            seed = self.get_seed_rng(ngram_tokens[ii])
            self.rng.manual_seed(seed)
            vocab_permutation = torch.randperm(vocab_size, generator=self.rng)
            greenlist = vocab_permutation[:int(self.gamma * vocab_size)]  # gamma * n
            bias = torch.zeros(vocab_size, device=logits.device)  # n
            bias[greenlist] = self.delta
            bias = bias.roll(-self.payload)
            logits[ii] += bias  # add bias to greenlist words
        return logits

    def _select_mix_salt_key(self, ii, ngram) -> bool:
        """Apply "mix"-mode salt selection for batch row ``ii``.

        Sets ``self.salt_key`` from ``self.key01[ii]`` and returns True, or
        returns False when that key is None (row left unwatermarked). Without
        "mix" mode the salt is untouched and True is returned.
        """
        if "mix" not in self.usermode:
            return True
        key = self.key01[ii]
        if key is None:
            return False
        self.salt_key = key
        if "alt2" in self.usermode:
            # Choose one of two salts; later options override earlier ones.
            pos = self.get_seed_rng(ngram, salt_key=0) % 2
            if "posenc" in self.usermode:
                pos = self.cur_pos % 2
            if "unb" in self.usermode:
                pos = int(self.get_seed_rng(ngram, salt_key=0) % self.usermode["unb"] > 0)
            if "r" in self.usermode:
                s = (self.get_seed_rng(ngram, salt_key=0) % 10) / 10
                pos = int(s > self.usermode["r"])
            self.salt_key = key[pos]
        return True
